/*
 * Copyright 2006 Google Inc.
 * 
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not
 * use this file except in compliance with the License. You may obtain a copy of
 * the License at
 * 
 * http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */
package com.google.gwt.user.server.rpc;

import com.google.gwt.user.client.rpc.RemoteService;
import com.google.gwt.user.client.rpc.SerializationException;
import com.google.gwt.user.server.rpc.impl.ServerSerializableTypeOracle;
import com.google.gwt.user.server.rpc.impl.ServerSerializableTypeOracleImpl;
import com.google.gwt.user.server.rpc.impl.ServerSerializationStreamReader;
import com.google.gwt.user.server.rpc.impl.ServerSerializationStreamWriter;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.UnsupportedEncodingException;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Set;
import java.util.zip.GZIPOutputStream;

import javax.servlet.ServletContext;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;


/**
 * The servlet base class for your RPC service implementations that
 * automatically deserializes incoming requests from the client and serializes
 * outgoing responses for client/server RPCs.
 */
public class RemoteServiceServlet extends HttpServlet {

  /*
   * These members are used to get and set the different HttpServletResponse and
   * HttpServletRequest headers.
   */
  private static final String ACCEPT_ENCODING = "Accept-Encoding";
  private static final String CHARSET_UTF8 = "UTF-8";
  private static final String CONTENT_ENCODING = "Content-Encoding";
  private static final String CONTENT_ENCODING_GZIP = "gzip";
  private static final String CONTENT_TYPE_TEXT_PLAIN_UTF8 = "text/plain; charset=utf-8";
  private static final String GENERIC_FAILURE_MSG = "The call failed on the server; see server log for details";
  private static final HashMap TYPE_NAMES;

  /**
   * Controls the compression threshold at and below which no compression will
   * take place.
   */
  private static final int UNCOMPRESSED_BYTE_SIZE_LIMIT = 256;

  static {
    TYPE_NAMES = new HashMap();
    TYPE_NAMES.put("Z", boolean.class);
    TYPE_NAMES.put("B", byte.class);
    TYPE_NAMES.put("C", char.class);
    TYPE_NAMES.put("D", double.class);
    TYPE_NAMES.put("F", float.class);
    TYPE_NAMES.put("I", int.class);
    TYPE_NAMES.put("J", long.class);
    TYPE_NAMES.put("S", short.class);
  }

  /**
   * Return true if the response object accepts Gzip encoding. This is done by
   * checking that the accept-encoding header specifies gzip as a supported
   * encoding.
   */
  private static boolean acceptsGzipEncoding(HttpServletRequest request) {
    assert (request != null);

    String acceptEncoding = request.getHeader(ACCEPT_ENCODING);
    if (null == acceptEncoding) {
      return false;
    }

    return (acceptEncoding.indexOf(CONTENT_ENCODING_GZIP) != -1);
  }

  /**
   * This method attempts to estimate the number of bytes that a string will
   * consume when it is sent out as part of an HttpServletResponse. This really
   * a hack since we are assuming that every character will consume two bytes
   * upon transmission. This is definitely not true since some characters
   * actually consume more than two bytes and some consume less. This is even
   * less accurate if the string is converted to UTF8. However, it does save us
   * from converting every string that we plan on sending back to UTF8 just to
   * determine that we should not compress it.
   */
  private static int estimateByteSize(final String buffer) {
    return (buffer.length() * 2);
  }

  /**
   * Find the invoked method on either the specified interface or any super.
   */
  private static Method findInterfaceMethod(Class intf, String methodName,
      Class[] paramTypes, boolean includeInherited) {
    try {
      return intf.getDeclaredMethod(methodName, paramTypes);
    } catch (NoSuchMethodException e) {
      if (includeInherited) {
        Class[] superintfs = intf.getInterfaces();
        for (int i = 0; i < superintfs.length; i++) {
          Method method = findInterfaceMethod(superintfs[i], methodName,
              paramTypes, true);
          if (method != null) {
            return method;
          }
        }
      }

      return null;
    }
  }

  private final Set knownImplementedInterfaces = new HashSet();

  private final ThreadLocal perThreadRequest = new ThreadLocal();

  private final ThreadLocal perThreadResponse = new ThreadLocal();

  private final ThreadLocal perThreadInCall = new ThreadLocal();

  private final ServerSerializableTypeOracle serializableTypeOracle;

  /**
   * The default constructor.
   */
  public RemoteServiceServlet() {
    serializableTypeOracle = new ServerSerializableTypeOracleImpl(
        getPackagePaths());
  }

  /**
   * This is called internally.
   */
  public final void doPost(HttpServletRequest request,
      HttpServletResponse response) {
    Throwable caught;
    try {
      // Store the request & response objects in thread-local storage.
      //
      perThreadRequest.set(request);
      perThreadResponse.set(response);
      perThreadInCall.set(new Boolean(true));

      // Read the request fully.
      //
      String requestPayload = readPayloadAsUtf8(request);

      // Invoke the core dispatching logic, which returns the serialized
      // result.
      //
      String responsePayload = processCall(requestPayload);

      // Write the response.
      //
      writeResponse(request, response, responsePayload);
      return;
    } catch (IOException e) {
      caught = e;
    } catch (ServletException e) {
      caught = e;
    } catch (SerializationException e) {
      caught = e;
    } catch (Throwable e) {
      caught = e;
    } finally {
    	perThreadInCall.set(null);
    }

    respondWithFailure(response, caught);
  }

  /**
   * This is public so that it can be unit tested easily without HTTP.
   */
  public String processCall(String payload) throws SerializationException {

    // Let subclasses see the serialized request.
    //
    onBeforeRequestDeserialized(payload);

    // Create a stream to deserialize the request.
    //
    ServerSerializationStreamReader streamReader = new ServerSerializationStreamReader(
        serializableTypeOracle);
    streamReader.prepareToRead(payload);

    // Read the service interface
    //
    String serviceIntfName = streamReader.readString();

    // TODO(mmendez): need a way to check the type signature of the service intf
    // Verify that this very servlet implements the specified interface name.
    //
    if (!isImplementedRemoteServiceInterface(serviceIntfName)) {
      // Bad payload, possible hack attempt.
      //
      throw new SecurityException(
          "Blocked attempt to access interface '"
              + serviceIntfName
              + "', which is either not implemented by this servlet or which doesn't extend RemoteService; this is either misconfiguration or a hack attempt");
    }

    // Actually get the service interface, so that we can query its methods.
    //
    Class serviceIntf;
    try {
      serviceIntf = getClassFromName(serviceIntfName);
    } catch (ClassNotFoundException e) {
      throw new SerializationException("Unknown service interface class '"
          + serviceIntfName + "'", e);
    }

    // Read the method name.
    //
    String methodName = streamReader.readString();

    // Read the number and names of the parameter classes from the stream.
    // We have to do this so that we can find the correct overload of the
    // method.
    //
    int paramCount = streamReader.readInt();
    Class[] paramTypes = new Class[paramCount];
    for (int i = 0; i < paramTypes.length; i++) {
      String paramClassName = streamReader.readString();
      try {
        paramTypes[i] = getClassOrPrimitiveFromName(paramClassName);
      } catch (ClassNotFoundException e) {
        throw new SerializationException("Unknown parameter " + i + " type '"
            + paramClassName + "'", e);
      }
    }

    // For security, make sure the method is found in the service interface
    // and not just one that happens to be defined on this class.
    //
    Method serviceIntfMethod = findInterfaceMethod(serviceIntf, methodName,
        paramTypes, true);

    // If it wasn't found, don't continue.
    //
    if (serviceIntfMethod == null) {
      // Bad payload, possible hack attempt.
      //
      throw new SecurityException(
          "Method '"
              + methodName
              + "' (or a particular overload) on interface '"
              + serviceIntfName
              + "' was not found, this is either misconfiguration or a hack attempt");
    }

    // Deserialize the parameters.
    //
    Object[] args = new Object[paramCount];
    for (int i = 0; i < args.length; i++) {
      args[i] = streamReader.deserializeValue(paramTypes[i]);
    }

    // Make the call via reflection.
    //
    String responsePayload = GENERIC_FAILURE_MSG;
    ServerSerializationStreamWriter streamWriter = new ServerSerializationStreamWriter(
        serializableTypeOracle);
    Throwable caught = null;
    try {
      Class returnType = serviceIntfMethod.getReturnType();
      Object returnVal = serviceIntfMethod.invoke(this, args);
      responsePayload = createResponse(streamWriter, returnType, returnVal,
          false);
    } catch (IllegalArgumentException e) {
      caught = e;
    } catch (IllegalAccessException e) {
      caught = e;
    } catch (InvocationTargetException e) {
      // Try to serialize the caught exception if the client is expecting it,
      // otherwise log the exception server-side.
      caught = e;
      Throwable cause = e.getCause();
      if (cause != null) {
        // Update the caught exception to the underlying cause
        caught = cause;
        // Serialize the exception back to the client if it's a declared
        // exception
        if (isExpectedException(serviceIntfMethod, cause)) {
          Class thrownClass = cause.getClass();
          responsePayload = createResponse(streamWriter, thrownClass, cause,
              true);
          // Don't log the exception on the server
          caught = null;
        }
      }
    }

    if (caught != null) {
      responsePayload = GENERIC_FAILURE_MSG;
      ServletContext servletContext = getServletContext();
      // servletContext may be null (for example, when unit testing)
      if (servletContext != null) {
        // Log the exception server side
        servletContext.log("Exception while dispatching incoming RPC call",
            caught);
      }
    }

    // Let subclasses see the serialized response.
    //
    onAfterResponseSerialized(responsePayload);

    return responsePayload;
  }

  /**
   * Gets the <code>HttpServletRequest</code> object for the current call. It
   * is stored thread-locally so that simultaneous invocations can have
   * different request objects.
   */
  protected final HttpServletRequest getThreadLocalRequest() {
    return (HttpServletRequest) perThreadRequest.get();
  }

  /**
   * Gets the <code>HttpServletResponse</code> object for the current call. It
   * is stored thread-locally so that simultaneous invocations can have
   * different response objects.
   */
  protected final HttpServletResponse getThreadLocalResponse() {
    return (HttpServletResponse) perThreadResponse.get();
  }

  /**
   * Override this method to examine the serialized response that will be
   * returned to the client. The default implementation does nothing and need
   * not be called by subclasses.
   */
  protected void onAfterResponseSerialized(String serializedResponse) {
  }

  /**
   * Override this method to examine the serialized version of the request
   * payload before it is deserialized into objects. The default implementation
   * does nothing and need not be called by subclasses.
   */
  protected void onBeforeRequestDeserialized(String serializedRequest) {
  }

  /**
   * Determines whether the response to a given servlet request should or should
   * not be GZIP compressed. This method is only called in cases where the
   * requestor accepts GZIP encoding.
   * <p>
   * This implementation currently returns <code>true</code> if the response
   * string's estimated byte length is longer than 256 bytes. Subclasses can
   * override this logic.
   * </p>
   * 
   * @param request the request being served
   * @param response the response that will be written into
   * @param responsePayload the payload that is about to be sent to the client
   * @return <code>true</code> if responsePayload should be GZIP compressed,
   *         otherwise <code>false</code>.
   */
  protected boolean shouldCompressResponse(HttpServletRequest request,
      HttpServletResponse response, String responsePayload) {
    return estimateByteSize(responsePayload) > UNCOMPRESSED_BYTE_SIZE_LIMIT;
  }

  /**
   * @param stream
   * @param responseType
   * @param responseObj
   * @param isException
   * @return response
   */
  private String createResponse(ServerSerializationStreamWriter stream,
      Class responseType, Object responseObj, boolean isException) {
    stream.prepareToWrite();
    if (responseType != void.class) {
      try {
        stream.serializeValue(responseObj, responseType);
      } catch (SerializationException e) {
        responseObj = e;
        isException = true;
      }
    }

    String bufferStr = (isException ? "{EX}" : "{OK}") + stream.toString();
    return bufferStr;
  }

  /**
   * Returns the {@link Class} instance for the named class.
   * 
   * @param name the name of a class or primitive type
   * @return Class instance for the given type name
   * @throws ClassNotFoundException if the named type was not found
   */
  private Class getClassFromName(String name) throws ClassNotFoundException {
    return Class.forName(name, false, this.getClass().getClassLoader());
  }

  /**
   * Returns the {@link Class} instance for the named class or primitive type.
   * 
   * @param name the name of a class or primitive type
   * @return Class instance for the given type name
   * @throws ClassNotFoundException if the named type was not found
   */
  private Class getClassOrPrimitiveFromName(String name)
      throws ClassNotFoundException {
    Object value = TYPE_NAMES.get(name);
    if (value != null) {
      return (Class) value;
    }

    return getClassFromName(name);
  }

  /**
   * Obtain the special package-prefixes we use to check for custom serializers
   * that would like to live in a package that they cannot. For example,
   * "java.util.ArrayList" is in a sealed package, so instead we use this prefix
   * to check for a custom serializer in
   * "com.google.gwt.user.client.rpc.core.java.util.ArrayList". Right now, it's
   * hard-coded because we don't have a pressing need for this mechanism to be
   * extensible, but it is imaginable, which is why it's implemented this way.
   */
  private String[] getPackagePaths() {
    return new String[] {"com.google.gwt.user.client.rpc.core"};
  }

  /**
   * Returns true if the {@link java.lang.reflect.Method Method} definition on
   * the service is specified to throw the exception contained in the
   * InvocationTargetException or false otherwise. NOTE we do not check that the
   * type is serializable here. We assume that it must be otherwise the
   * application would never have been allowed to run.
   * 
   * @param serviceIntfMethod
   * @param e
   * @return is expected exception
   */
  private boolean isExpectedException(Method serviceIntfMethod, Throwable cause) {
    assert (serviceIntfMethod != null);
    assert (cause != null);

    Class[] exceptionsThrown = serviceIntfMethod.getExceptionTypes();
    if (exceptionsThrown.length <= 0) {
      // The method is not specified to throw any exceptions
      //
      return false;
    }

    Class causeType = cause.getClass();

    for (int index = 0; index < exceptionsThrown.length; ++index) {
      Class exceptionThrown = exceptionsThrown[index];
      assert (exceptionThrown != null);

      if (exceptionThrown.isAssignableFrom(causeType)) {
        return true;
      }
    }

    return false;
  }

  /**
   * Used to determine whether the specified interface name is implemented by
   * this class without loading the class (for security).
   */
  private boolean isImplementedRemoteServiceInterface(String intfName) {
    synchronized (knownImplementedInterfaces) {
      // See if it's cached.
      //
      if (knownImplementedInterfaces.contains(intfName)) {
        return true;
      }

      Class cls = getClass();

      // Unknown, so walk up the class hierarchy to find the first class that
      // implements the requested interface
      //
      while ((cls != null) && !RemoteServiceServlet.class.equals(cls)) {
        Class[] intfs = cls.getInterfaces();
        for (int i = 0; i < intfs.length; i++) {
          Class intf = intfs[i];
          if (isImplementedRemoteServiceInterfaceRecursive(intfName, intf)) {
            knownImplementedInterfaces.add(intfName);
            return true;
          }
        }

        // did not find the interface in this class so we look in the
        // superclass
        cls = cls.getSuperclass();
      }

      return false;
    }
  }

  /**
   * Only called from isImplementedInterface().
   */
  private boolean isImplementedRemoteServiceInterfaceRecursive(String intfName,
      Class intfToCheck) {
    assert (intfToCheck.isInterface());

    if (intfToCheck.getName().equals(intfName)) {
      // The name is right, but we also verify that it is assignable to
      // RemoteService.
      // 
      if (RemoteService.class.isAssignableFrom(intfToCheck)) {
        return true;
      } else {
        return false;
      }
    }

    Class[] intfs = intfToCheck.getInterfaces();
    for (int i = 0; i < intfs.length; i++) {
      Class intf = intfs[i];
      if (isImplementedRemoteServiceInterfaceRecursive(intfName, intf)) {
        return true;
      }
    }

    return false;
  }

  private String readPayloadAsUtf8(HttpServletRequest request)
      throws IOException, ServletException {
    int contentLength = request.getContentLength();
    if (contentLength == -1) {
      // Content length must be known.
      throw new ServletException("Content-Length must be specified");
    }

    String contentType = request.getContentType();
    boolean contentTypeIsOkay = false;
    // Content-Type must be specified.
    if (contentType != null) {
      // The type must be plain text.
      if (contentType.startsWith("text/plain")) {
        // And it must be UTF-8 encoded (or unspecified, in which case we assume
        // that it's either UTF-8 or ASCII).
        if (contentType.indexOf("charset=") == -1) {
          contentTypeIsOkay = true;
        } else if (contentType.indexOf("charset=utf-8") != -1) {
          contentTypeIsOkay = true;
        }
      }
    }
    if (!contentTypeIsOkay) {
      throw new ServletException(
          "Content-Type must be 'text/plain' with 'charset=utf-8' (or unspecified charset)");
    }
    InputStream in = request.getInputStream();
    try {
      byte[] payload = new byte[contentLength];
      int offset = 0;
      int len = contentLength;
      int byteCount;
      while (offset < contentLength) {
        byteCount = in.read(payload, offset, len);
        if (byteCount == -1) {
          throw new ServletException("Client did not send " + contentLength
              + " bytes as expected");
        }
        offset += byteCount;
        len -= byteCount;
      }
      return new String(payload, "UTF-8");
    } finally {
      if (in != null) {
        in.close();
      }
    }
  }

  /**
   * Called when the machinery of this class itself has a problem, rather than
   * the invoked third-party method. It writes a simple 500 message back to the
   * client.
   */
  private void respondWithFailure(HttpServletResponse response, Throwable caught) {
    ServletContext servletContext = getServletContext();
    servletContext.log("Exception while dispatching incoming RPC call", caught);
    try {
      response.setContentType("text/plain");
      response.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR);
      response.getWriter().write(GENERIC_FAILURE_MSG);
    } catch (IOException e) {
      servletContext.log(
          "sendError() failed while sending the previous failure to the client",
          caught);
    }
  }

  private void writeResponse(HttpServletRequest request,
      HttpServletResponse response, String responsePayload) throws IOException {

    byte[] reply = responsePayload.getBytes(CHARSET_UTF8);
    String contentType = CONTENT_TYPE_TEXT_PLAIN_UTF8;

    if (acceptsGzipEncoding(request)
        && shouldCompressResponse(request, response, responsePayload)) {
      // Compress the reply and adjust headers.
      //
      ByteArrayOutputStream output = null;
      GZIPOutputStream gzipOutputStream = null;
      Throwable caught = null;
      try {
        output = new ByteArrayOutputStream(reply.length);
        gzipOutputStream = new GZIPOutputStream(output);
        gzipOutputStream.write(reply);
        gzipOutputStream.finish();
        gzipOutputStream.flush();
        response.setHeader(CONTENT_ENCODING, CONTENT_ENCODING_GZIP);
        reply = output.toByteArray();
      } catch (UnsupportedEncodingException e) {
        caught = e;
      } catch (IOException e) {
        caught = e;
      } finally {
        if (null != gzipOutputStream) {
          gzipOutputStream.close();
        }
        if (null != output) {
          output.close();
        }
      }

      if (caught != null) {
        getServletContext().log("Unable to compress response", caught);
        response.sendError(HttpServletResponse.SC_INTERNAL_SERVER_ERROR);
        return;
      }
    }

    // Send the reply.
    //
    response.setContentLength(reply.length);
    response.setContentType(contentType);
    response.setStatus(HttpServletResponse.SC_OK);
    response.getOutputStream().write(reply);
  }
  
  protected final void setThreadLocalRequest(HttpServletRequest request) {
	  if (isThreadInCall()) {
		  throw new IllegalStateException("The request can only be set before the call starts");
	  }
	  perThreadRequest.set(request);
  }
  
  protected final void setThreadLocalResponse(HttpServletResponse response) {
	  if (isThreadInCall()) {
		  throw new IllegalStateException("The request can only be set before the call starts");
	  }
	  perThreadResponse.set(response);
  }

	private boolean isThreadInCall() {
		Boolean isThreadInCall = (Boolean) perThreadInCall.get();
		if (isThreadInCall != null) {
			return isThreadInCall.booleanValue();
		}
		return false;
	}

	protected final void setThreadInCall(boolean inCall) {
		perThreadInCall.set(new Boolean(inCall));
    }

}
